# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock

from sqlalchemy import column, literal_column, select, table
from sqlalchemy.dialects import mssql, oracle, postgresql
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.types import String, UnicodeText

from superset.db_engine_specs import engines
from superset.db_engine_specs.base import (
    BaseEngineSpec,
    builtin_time_grains,
    create_time_grains_tuple,
)
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.oracle import OracleEngineSpec
from superset.db_engine_specs.pinot import PinotEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.core import Database
from .base_tests import SupersetTestCase


class DbEngineSpecsTestCase(SupersetTestCase):
    def test_0_progress(self):
        log = """
            17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
            17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=parse from=org.apache.hadoop.hive.ql.Driver>
        """.split('\n')  # noqa ignore: E501
        self.assertEquals(
            0, HiveEngineSpec.progress(log))

    def test_number_of_jobs_progress(self):
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
        """.split('\n')
        self.assertEquals(0, HiveEngineSpec.progress(log))

    def test_job_1_launched_progress(self):
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
        """.split('\n')
        self.assertEquals(0, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_1_0_progress(self):
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%,  reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEquals(0, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_1_map_40_progress(self):
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%,  reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEquals(10, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_1_map_80_reduce_40_progress(self):
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%,  reduce = 40%
        """.split('\n')  # noqa ignore: E501
        self.assertEquals(30, HiveEngineSpec.progress(log))

    def test_job_1_launched_stage_2_stages_progress(self):
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 80%,  reduce = 40%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-2 map = 0%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%,  reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEquals(12, HiveEngineSpec.progress(log))

    def test_job_2_launched_stage_2_stages_progress(self):
        log = """
            17/02/07 19:15:55 INFO ql.Driver: Total jobs = 2
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 1 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 100%,  reduce = 0%
            17/02/07 19:15:55 INFO ql.Driver: Launching Job 2 out of 2
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 0%,  reduce = 0%
            17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%,  reduce = 0%
        """.split('\n')  # noqa ignore: E501
        self.assertEquals(60, HiveEngineSpec.progress(log))

    def test_hive_error_msg(self):
        msg = (
            '{...} errorMessage="Error while compiling statement: FAILED: '
            'SemanticException [Error 10001]: Line 4'
            ':5 Table not found \'fact_ridesfdslakj\'", statusCode=3, '
            "sqlState='42S02', errorCode=10001)){...}")
        self.assertEquals((
            'Error while compiling statement: FAILED: '
            'SemanticException [Error 10001]: Line 4:5 '
            "Table not found 'fact_ridesfdslakj'"),
            HiveEngineSpec.extract_error_message(Exception(msg)))

        e = Exception("Some string that doesn't match the regex")
        self.assertEquals(
            str(e), HiveEngineSpec.extract_error_message(e))

        msg = (
            'errorCode=10001, '
            'errorMessage="Error while compiling statement"), operationHandle'
            '=None)"'
        )
        self.assertEquals((
            'Error while compiling statement'),
            HiveEngineSpec.extract_error_message(Exception(msg)))

    def get_generic_database(self):
        return Database(sqlalchemy_uri='mysql://localhost')

    def sql_limit_regex(
            self, sql, expected_sql,
            engine_spec_class=MySQLEngineSpec,
            limit=1000):
        main = self.get_generic_database()
        limited = engine_spec_class.apply_limit_to_sql(sql, limit, main)
        self.assertEquals(expected_sql, limited)

    def test_extract_limit_from_query(self, engine_spec_class=MySQLEngineSpec):
        q0 = 'select * from table'
        q1 = 'select * from mytable limit 10'
        q2 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20'
        q3 = 'select * from (select * from my_subquery limit 10);'
        q4 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20;'
        q5 = 'select * from mytable limit 10, 20'
        q6 = 'select * from mytable limit 10 offset 20'
        q7 = 'select * from mytable limit'
        q8 = 'select * from mytable limit 10.0'
        q9 = 'select * from mytable limit x'
        q10 = 'select * from mytable limit x, 20'
        q11 = 'select * from mytable limit x offset 20'

        self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q5), 10)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q6), 10)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q7), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q8), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q9), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q10), None)
        self.assertEqual(engine_spec_class.get_limit_from_sql(q11), None)

    def test_wrapped_query(self):
        self.sql_limit_regex(
            'SELECT * FROM a',
            'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000',
            MssqlEngineSpec,
        )

    def test_wrapped_semi(self):
        self.sql_limit_regex(
            'SELECT * FROM a;',
            'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000',
            MssqlEngineSpec,
        )

    def test_wrapped_semi_tabs(self):
        self.sql_limit_regex(
            'SELECT * FROM a  \t \n   ; \t  \n  ',
            'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000',
            MssqlEngineSpec,
        )

    def test_simple_limit_query(self):
        self.sql_limit_regex(
            'SELECT * FROM a',
            'SELECT * FROM a\nLIMIT 1000',
        )

    def test_modify_limit_query(self):
        self.sql_limit_regex(
            'SELECT * FROM a LIMIT 9999',
            'SELECT * FROM a LIMIT 1000',
        )

    def test_limit_query_with_limit_subquery(self):
        self.sql_limit_regex(
            'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999',
            'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000',
        )

    def test_limit_with_expr(self):
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 1000""",
        )

    def test_limit_expr_and_semicolon(self):
        self.sql_limit_regex(
            """
                SELECT
                    'LIMIT 777' AS a
                    , b
                FROM
                table
                LIMIT         99990            ;""",
            """SELECT
                    'LIMIT 777' AS a
                    , b
                FROM
                table
                LIMIT         1000""",
        )

    def test_get_datatype(self):
        self.assertEquals('STRING', PrestoEngineSpec.get_datatype('string'))
        self.assertEquals('TINY', MySQLEngineSpec.get_datatype(1))
        self.assertEquals('VARCHAR', MySQLEngineSpec.get_datatype(15))
        self.assertEquals('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))

    def test_limit_with_implicit_offset(self):
        self.sql_limit_regex(
            """
                SELECT
                    'LIMIT 777' AS a
                    , b
                FROM
                table
                LIMIT 99990, 999999""",
            """SELECT
                    'LIMIT 777' AS a
                    , b
                FROM
                table
                LIMIT 99990, 1000""",
        )

    def test_limit_with_explicit_offset(self):
        self.sql_limit_regex(
            """
                SELECT
                    'LIMIT 777' AS a
                    , b
                FROM
                table
                LIMIT 99990
                OFFSET 999999""",
            """SELECT
                    'LIMIT 777' AS a
                    , b
                FROM
                table
                LIMIT 1000
                OFFSET 999999""",
        )

    def test_limit_with_non_token_limit(self):
        self.sql_limit_regex(
            """
                SELECT
                    'LIMIT 777'""",
            """
                SELECT
                    'LIMIT 777'\nLIMIT 1000""",
        )

    def test_time_grain_blacklist(self):
        blacklist = ['PT1M']
        time_grains = {
            'PT1S': 'second',
            'PT1M': 'minute',
        }
        time_grain_functions = {
            'PT1S': '{col}',
            'PT1M': '{col}',
        }
        time_grains = create_time_grains_tuple(time_grains,
                                               time_grain_functions,
                                               blacklist)
        self.assertEqual(1, len(time_grains))
        self.assertEqual('PT1S', time_grains[0].duration)

    def test_engine_time_grain_validity(self):
        time_grains = set(builtin_time_grains.keys())
        # loop over all subclasses of BaseEngineSpec
        for engine in engines.values():
            if engine is not BaseEngineSpec:
                # make sure time grain functions have been defined
                self.assertGreater(len(engine.time_grain_functions), 0)
                # make sure that all defined time grains are supported
                defined_grains = {grain.duration for grain in engine.get_time_grains()}
                intersection = time_grains.intersection(defined_grains)
                self.assertSetEqual(defined_grains, intersection, engine)

    def test_presto_get_view_names_return_empty_list(self):
        self.assertEquals([], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY))

    def verify_presto_column(self, column, expected_results):
        inspector = mock.Mock()
        inspector.engine.dialect.identifier_preparer.quote_identifier = mock.Mock()
        keymap = {'Column': (None, None, 0),
                  'Type': (None, None, 1),
                  'Null': (None, None, 2)}
        row = RowProxy(mock.Mock(), column, [None, None, None, None], keymap)
        inspector.bind.execute = mock.Mock(return_value=[row])
        results = PrestoEngineSpec.get_columns(inspector, '', '')
        self.assertEqual(len(expected_results), len(results))
        for expected_result, result in zip(expected_results, results):
            self.assertEqual(expected_result[0], result['name'])
            self.assertEqual(expected_result[1], str(result['type']))

    def test_presto_get_column(self):
        presto_column = ('column_name', 'boolean', '')
        expected_results = [('column_name', 'BOOLEAN')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_simple_row_column(self):
        presto_column = ('column_name', 'row(nested_obj double)', '')
        expected_results = [
            ('column_name', 'ROW'),
            ('column_name.nested_obj', 'FLOAT')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_simple_row_column_with_name_containing_whitespace(self):
        presto_column = ('column name', 'row(nested_obj double)', '')
        expected_results = [
            ('column name', 'ROW'),
            ('column name.nested_obj', 'FLOAT')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_simple_row_column_with_tricky_nested_field_name(self):
        presto_column = ('column_name', 'row("Field Name(Tricky, Name)" double)', '')
        expected_results = [
            ('column_name', 'ROW'),
            ('column_name."Field Name(Tricky, Name)"', 'FLOAT')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_simple_array_column(self):
        presto_column = ('column_name', 'array(double)', '')
        expected_results = [('column_name', 'ARRAY')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_row_within_array_within_row_column(self):
        presto_column = (
            'column_name',
            'row(nested_array array(row(nested_row double)), nested_obj double)', '')
        expected_results = [
            ('column_name', 'ROW'),
            ('column_name.nested_array', 'ARRAY'),
            ('column_name.nested_array.nested_row', 'FLOAT'),
            ('column_name.nested_obj', 'FLOAT'),
        ]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_array_within_row_within_array_column(self):
        presto_column = (
            'column_name',
            'array(row(nested_array array(double), nested_obj double))', '')
        expected_results = [
            ('column_name', 'ARRAY'),
            ('column_name.nested_array', 'ARRAY'),
            ('column_name.nested_obj', 'FLOAT')]
        self.verify_presto_column(presto_column, expected_results)

    def test_presto_get_fields(self):
        cols = [
            {'name': 'column'},
            {'name': 'column.nested_obj'},
            {'name': 'column."quoted.nested obj"'}]
        actual_results = PrestoEngineSpec._get_fields(cols)
        expected_results = [
            {'name': '"column"', 'label': 'column'},
            {'name': '"column"."nested_obj"', 'label': 'column.nested_obj'},
            {'name': '"column"."quoted.nested obj"',
             'label': 'column."quoted.nested obj"'}]
        for actual_result, expected_result in zip(actual_results, expected_results):
            self.assertEqual(actual_result.element.name, expected_result['name'])
            self.assertEqual(actual_result.name, expected_result['label'])

    def test_presto_filter_out_array_nested_cols(self):
        cols = [
            {'name': 'column', 'type': 'ARRAY'},
            {'name': 'column.nested_obj', 'type': 'FLOAT'}]
        actual_filtered_cols,\
            actual_array_cols = PrestoEngineSpec._filter_out_array_nested_cols(cols)
        expected_filtered_cols = [{'name': 'column', 'type': 'ARRAY'}]
        self.assertEqual(actual_filtered_cols, expected_filtered_cols)
        self.assertEqual(actual_array_cols, cols)

    def test_presto_create_row_and_array_hierarchy(self):
        cols = [
            {'name': 'row_column',
             'type': 'ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR)'},
            {'name': 'array_column',
             'type': 'ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))'}]
        actual_row_col_hierarchy,\
            actual_array_col_hierarchy,\
            actual_expanded_cols = PrestoEngineSpec._create_row_and_array_hierarchy(cols)
        expected_row_col_hierarchy = {
            'row_column': {
                'type': 'ROW',
                'children': ['row_column.nested_obj1', 'row_column.nested_row'],
            },
            'row_column.nested_row': {
                'type': 'ROW',
                'children': ['row_column.nested_row.nested_obj2']},
        }
        expected_array_col_hierarchy = {
            'array_column': {
                'type': 'ARRAY',
                'children': ['array_column.nested_array'],
            },
            'array_column.nested_array': {
                'type': 'ARRAY',
                'children': ['array_column.nested_array.nested_obj']},
        }
        expected_expanded_cols = [
            {'name': 'row_column.nested_obj1', 'type': 'VARCHAR'},
            {'name': 'row_column.nested_row', 'type': 'ROW'},
            {'name': 'row_column.nested_row.nested_obj2', 'type': 'VARCHAR'},
            {'name': 'array_column.nested_array', 'type': 'ARRAY'},
            {'name': 'array_column.nested_array.nested_obj', 'type': 'VARCHAR'}]
        self.assertEqual(actual_row_col_hierarchy, expected_row_col_hierarchy)
        self.assertEqual(actual_array_col_hierarchy, expected_array_col_hierarchy)
        self.assertEqual(actual_expanded_cols, expected_expanded_cols)

    def test_presto_expand_row_data(self):
        datum = {'row_col': [1, 'a']}
        row_column = 'row_col'
        row_col_hierarchy = {
            'row_col': {
                'type': 'ROW',
                'children': ['row_col.nested_int', 'row_col.nested_str'],
            },
        }
        PrestoEngineSpec._expand_row_data(datum, row_column, row_col_hierarchy)
        expected_datum = {
            'row_col': [1, 'a'], 'row_col.nested_int': 1, 'row_col.nested_str': 'a',
        }
        self.assertEqual(datum, expected_datum)

    def test_split_array_columns_by_process_state(self):
        array_cols = ['array_column', 'array_column.nested_array']
        array_col_hierarchy = {
            'array_column': {
                'type': 'ARRAY',
                'children': ['array_column.nested_array'],
            },
            'array_column.nested_array': {
                'type': 'ARRAY',
                'children': ['array_column.nested_array.nested_obj']},
        }
        datum = {'array_column': [[[1], [2]]]}
        actual_array_cols_to_process, actual_unprocessed_array_cols = \
            PrestoEngineSpec._split_array_columns_by_process_state(
                array_cols, array_col_hierarchy, datum)
        expected_array_cols_to_process = ['array_column']
        expected_unprocessed_array_cols = {'array_column.nested_array'}
        self.assertEqual(actual_array_cols_to_process, expected_array_cols_to_process)
        self.assertEqual(actual_unprocessed_array_cols, expected_unprocessed_array_cols)

    def test_presto_convert_data_list_to_array_data_dict(self):
        data = [
            {'array_column': [1, 2], 'int_column': 3},
            {'array_column': [11, 22], 'int_column': 33},
        ]
        array_columns_to_process = ['array_column']
        actual_array_data_dict = PrestoEngineSpec._convert_data_list_to_array_data_dict(
            data, array_columns_to_process)
        expected_array_data_dict = {
            0: [{'array_column': [1, 2]}],
            1: [{'array_column': [11, 22]}]}
        self.assertEqual(actual_array_data_dict, expected_array_data_dict)

    def test_presto_process_array_data(self):
        data = [
            {'array_column': [[1], [2]], 'int_column': 3},
            {'array_column': [[11], [22]], 'int_column': 33},
        ]
        all_columns = [
            {'name': 'array_column', 'type': 'ARRAY'},
            {'name': 'array_column.nested_row', 'type': 'BIGINT'},
            {'name': 'int_column', 'type': 'BIGINT'},
        ]
        array_column_hierarchy = {
            'array_column': {
                'type': 'ARRAY',
                'children': ['array_column.nested_row'],
            },
        }
        actual_array_data = PrestoEngineSpec._process_array_data(
            data, all_columns, array_column_hierarchy)
        expected_array_data = {
            0: [
                {'array_column': [[1], [2]], 'array_column.nested_row': 1},
                {'array_column': '', 'array_column.nested_row': 2, 'int_column': ''},
            ],
            1: [
                {'array_column': [[11], [22]], 'array_column.nested_row': 11},
                {'array_column': '', 'array_column.nested_row': 22, 'int_column': ''},
            ],
        }
        self.assertEqual(actual_array_data, expected_array_data)

    def test_presto_consolidate_array_data_into_data(self):
        data = [
            {'arr_col': [[1], [2]], 'int_col': 3},
            {'arr_col': [[11], [22]], 'int_col': 33},
        ]
        array_data = {
            0: [
                {'arr_col': [[1], [2]], 'arr_col.nested_row': 1},
                {'arr_col': '', 'arr_col.nested_row': 2, 'int_col': ''},
            ],
            1: [
                {'arr_col': [[11], [22]], 'arr_col.nested_row': 11},
                {'arr_col': '', 'arr_col.nested_row': 22, 'int_col': ''},
            ],
        }
        PrestoEngineSpec._consolidate_array_data_into_data(data, array_data)
        expected_data = [
            {'arr_col': [[1], [2]], 'arr_col.nested_row': 1, 'int_col': 3},
            {'arr_col': '', 'arr_col.nested_row': 2, 'int_col': ''},
            {'arr_col': [[11], [22]], 'arr_col.nested_row': 11, 'int_col': 33},
            {'arr_col': '', 'arr_col.nested_row': 22, 'int_col': ''},
        ]
        self.assertEqual(data, expected_data)

    def test_presto_remove_processed_array_columns(self):
        array_col_hierarchy = {
            'array_column': {
                'type': 'ARRAY',
                'children': ['array_column.nested_array'],
            },
            'array_column.nested_array': {
                'type': 'ARRAY',
                'children': ['array_column.nested_array.nested_obj']},
        }
        unprocessed_array_cols = {'array_column.nested_array'}
        PrestoEngineSpec._remove_processed_array_columns(
            unprocessed_array_cols, array_col_hierarchy)
        expected_array_col_hierarchy = {
            'array_column.nested_array': {
                'type': 'ARRAY',
                'children': ['array_column.nested_array.nested_obj']},
        }
        self.assertEqual(array_col_hierarchy, expected_array_col_hierarchy)

    def test_presto_expand_data_with_simple_structural_columns(self):
        cols = [
            {'name': 'row_column', 'type': 'ROW(NESTED_OBJ VARCHAR)'},
            {'name': 'array_column', 'type': 'ARRAY(BIGINT)'}]
        data = [
            {'row_column': ['a'], 'array_column': [1, 2, 3]},
            {'row_column': ['b'], 'array_column': [4, 5, 6]}]
        actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
            cols, data)
        expected_cols = [
            {'name': 'row_column', 'type': 'ROW'},
            {'name': 'row_column.nested_obj', 'type': 'VARCHAR'},
            {'name': 'array_column', 'type': 'ARRAY'}]
        expected_data = [
            {'row_column': ['a'], 'row_column.nested_obj': 'a', 'array_column': 1},
            {'row_column': '', 'row_column.nested_obj': '', 'array_column': 2},
            {'row_column': '', 'row_column.nested_obj': '', 'array_column': 3},
            {'row_column': ['b'], 'row_column.nested_obj': 'b', 'array_column': 4},
            {'row_column': '', 'row_column.nested_obj': '', 'array_column': 5},
            {'row_column': '', 'row_column.nested_obj': '', 'array_column': 6}]
        expected_expanded_cols = [
            {'name': 'row_column.nested_obj', 'type': 'VARCHAR'}]
        self.assertEqual(actual_cols, expected_cols)
        self.assertEqual(actual_data, expected_data)
        self.assertEqual(actual_expanded_cols, expected_expanded_cols)

    def test_presto_expand_data_with_complex_row_columns(self):
        cols = [
            {'name': 'row_column',
             'type': 'ROW(NESTED_OBJ1 VARCHAR, NESTED_ROW ROW(NESTED_OBJ2 VARCHAR)'}]
        data = [
            {'row_column': ['a1', ['a2']]},
            {'row_column': ['b1', ['b2']]}]
        actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
            cols, data)
        expected_cols = [
            {'name': 'row_column', 'type': 'ROW'},
            {'name': 'row_column.nested_obj1', 'type': 'VARCHAR'},
            {'name': 'row_column.nested_row', 'type': 'ROW'},
            {'name': 'row_column.nested_row.nested_obj2', 'type': 'VARCHAR'}]
        expected_data = [
            {'row_column': ['a1', ['a2']],
             'row_column.nested_obj1': 'a1',
             'row_column.nested_row': ['a2'],
             'row_column.nested_row.nested_obj2': 'a2'},
            {'row_column': ['b1', ['b2']],
             'row_column.nested_obj1': 'b1',
             'row_column.nested_row': ['b2'],
             'row_column.nested_row.nested_obj2': 'b2'}]
        expected_expanded_cols = [
            {'name': 'row_column.nested_obj1', 'type': 'VARCHAR'},
            {'name': 'row_column.nested_row', 'type': 'ROW'},
            {'name': 'row_column.nested_row.nested_obj2', 'type': 'VARCHAR'}]
        self.assertEqual(actual_cols, expected_cols)
        self.assertEqual(actual_data, expected_data)
        self.assertEqual(actual_expanded_cols, expected_expanded_cols)

    def test_presto_expand_data_with_complex_array_columns(self):
        cols = [
            {'name': 'int_column', 'type': 'BIGINT'},
            {'name': 'array_column',
             'type': 'ARRAY(ROW(NESTED_ARRAY ARRAY(ROW(NESTED_OBJ VARCHAR))))'}]
        data = [
            {'int_column': 1, 'array_column': [[[['a'], ['b']]], [[['c'], ['d']]]]},
            {'int_column': 2, 'array_column': [[[['e'], ['f']]], [[['g'], ['h']]]]}]
        actual_cols, actual_data, actual_expanded_cols = PrestoEngineSpec.expand_data(
            cols, data)
        expected_cols = [
            {'name': 'int_column', 'type': 'BIGINT'},
            {'name': 'array_column', 'type': 'ARRAY'},
            {'name': 'array_column.nested_array', 'type': 'ARRAY'},
            {'name': 'array_column.nested_array.nested_obj', 'type': 'VARCHAR'}]
        expected_data = [
            {'int_column': 1,
             'array_column': [[[['a'], ['b']]], [[['c'], ['d']]]],
             'array_column.nested_array': [['a'], ['b']],
             'array_column.nested_array.nested_obj': 'a'},
            {'int_column': '',
             'array_column': '',
             'array_column.nested_array': '',
             'array_column.nested_array.nested_obj': 'b'},
            {'int_column': '',
             'array_column': '',
             'array_column.nested_array': [['c'], ['d']],
             'array_column.nested_array.nested_obj': 'c'},
            {'int_column': '',
             'array_column': '',
             'array_column.nested_array': '',
             'array_column.nested_array.nested_obj': 'd'},
            {'int_column': 2,
             'array_column': [[[['e'], ['f']]], [[['g'], ['h']]]],
             'array_column.nested_array': [['e'], ['f']],
             'array_column.nested_array.nested_obj': 'e'},
            {'int_column': '',
             'array_column': '',
             'array_column.nested_array': '',
             'array_column.nested_array.nested_obj': 'f'},
            {'int_column': '',
             'array_column': '',
             'array_column.nested_array': [['g'], ['h']],
             'array_column.nested_array.nested_obj': 'g'},
            {'int_column': '',
             'array_column': '',
             'array_column.nested_array': '',
             'array_column.nested_array.nested_obj': 'h'}]
        expected_expanded_cols = [
            {'name': 'array_column.nested_array', 'type': 'ARRAY'},
            {'name': 'array_column.nested_array.nested_obj', 'type': 'VARCHAR'}]
        self.assertEqual(actual_cols, expected_cols)
        self.assertEqual(actual_data, expected_data)
        self.assertEqual(actual_expanded_cols, expected_expanded_cols)

    def test_hive_get_view_names_return_empty_list(self):
        self.assertEquals([], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY))

    def test_bigquery_sqla_column_label(self):
        label = BigQueryEngineSpec.make_label_compatible(column('Col').name)
        label_expected = 'Col'
        self.assertEqual(label, label_expected)

        label = BigQueryEngineSpec.make_label_compatible(column('SUM(x)').name)
        label_expected = 'SUM_x__5f110b965a993675bc4953bb3e03c4a5'
        self.assertEqual(label, label_expected)

        label = BigQueryEngineSpec.make_label_compatible(column('SUM[x]').name)
        label_expected = 'SUM_x__7ebe14a3f9534aeee125449b0bc083a8'
        self.assertEqual(label, label_expected)

        label = BigQueryEngineSpec.make_label_compatible(column('12345_col').name)
        label_expected = '_12345_col_8d3906e2ea99332eb185f7f8ecb2ffd6'
        self.assertEqual(label, label_expected)

    def test_oracle_sqla_column_name_length_exceeded(self):
        col = column('This_Is_32_Character_Column_Name')
        label = OracleEngineSpec.make_label_compatible(col.name)
        self.assertEqual(label.quote, True)
        label_expected = '3b26974078683be078219674eeb8f5'
        self.assertEqual(label, label_expected)

    def test_mssql_column_types(self):
        def assert_type(type_string, type_expected):
            type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
            if type_expected is None:
                self.assertIsNone(type_assigned)
            else:
                self.assertIsInstance(type_assigned, type_expected)

        assert_type('INT', None)
        assert_type('STRING', String)
        assert_type('CHAR(10)', String)
        assert_type('VARCHAR(10)', String)
        assert_type('TEXT', String)
        assert_type('NCHAR(10)', UnicodeText)
        assert_type('NVARCHAR(10)', UnicodeText)
        assert_type('NTEXT', UnicodeText)

    def test_mssql_where_clause_n_prefix(self):
        dialect = mssql.dialect()
        spec = MssqlEngineSpec
        str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)'))
        unicode_col = column('unicode_col', type_=spec.get_sqla_column_type('NTEXT'))
        tbl = table('tbl')
        sel = select([str_col, unicode_col]).\
            select_from(tbl).\
            where(str_col == 'abc').\
            where(unicode_col == 'abc')

        query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
        query_expected = 'SELECT col, unicode_col \n' \
                         'FROM tbl \n' \
                         "WHERE col = 'abc' AND unicode_col = N'abc'"
        self.assertEqual(query, query_expected)

    def test_get_table_names(self):
        inspector = mock.Mock()
        inspector.get_table_names = mock.Mock(return_value=['schema.table', 'table_2'])
        inspector.get_foreign_table_names = mock.Mock(return_value=['table_3'])

        """ Make sure base engine spec removes schema name from table name
        ie. when try_remove_schema_from_table_name == True. """
        base_result_expected = ['table', 'table_2']
        base_result = BaseEngineSpec.get_table_names(
            schema='schema', inspector=inspector)
        self.assertListEqual(base_result_expected, base_result)

        """ Make sure postgres doesn't try to remove schema name from table name
        ie. when try_remove_schema_from_table_name == False. """
        pg_result_expected = ['schema.table', 'table_2', 'table_3']
        pg_result = PostgresEngineSpec.get_table_names(
            schema='schema', inspector=inspector)
        self.assertListEqual(pg_result_expected, pg_result)

    def test_pg_time_expression_literal_no_grain(self):
        col = literal_column('COALESCE(a, b)')
        expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
        result = str(expr.compile(dialect=postgresql.dialect()))
        self.assertEqual(result, 'COALESCE(a, b)')

    def test_pg_time_expression_literal_1y_grain(self):
        col = literal_column('COALESCE(a, b)')
        expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y')
        result = str(expr.compile(dialect=postgresql.dialect()))
        self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")

    def test_pg_time_expression_lower_column_no_grain(self):
        col = column('lower_case')
        expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
        result = str(expr.compile(dialect=postgresql.dialect()))
        self.assertEqual(result, 'lower_case')

    def test_pg_time_expression_lower_case_column_sec_1y_grain(self):
        col = column('lower_case')
        expr = PostgresEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1Y')
        result = str(expr.compile(dialect=postgresql.dialect()))
        self.assertEqual(result, "DATE_TRUNC('year', (timestamp 'epoch' + lower_case * interval '1 second'))")  # noqa

    def test_pg_time_expression_mixed_case_column_1y_grain(self):
        col = column('MixedCase')
        expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y')
        result = str(expr.compile(dialect=postgresql.dialect()))
        self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")

    def test_mssql_time_expression_mixed_case_column_1y_grain(self):
        col = column('MixedCase')
        expr = MssqlEngineSpec.get_timestamp_expr(col, None, 'P1Y')
        result = str(expr.compile(dialect=mssql.dialect()))
        self.assertEqual(result, 'DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)')

    def test_oracle_time_expression_reserved_keyword_1m_grain(self):
        col = column('decimal')
        expr = OracleEngineSpec.get_timestamp_expr(col, None, 'P1M')
        result = str(expr.compile(dialect=oracle.dialect()))
        self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")

    def test_pinot_time_expression_sec_1m_grain(self):
        col = column('tstamp')
        expr = PinotEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1M')
        result = str(expr.compile())
        self.assertEqual(result, 'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")')  # noqa

    def test_column_datatype_to_string(self):
        main_db = self.get_main_database()
        sqla_table = main_db.get_table('energy_usage')
        dialect = main_db.get_dialect()
        col_names = [
            main_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
            for c in sqla_table.columns
        ]
        if main_db.backend == 'postgresql':
            expected = ['VARCHAR(255)', 'VARCHAR(255)', 'DOUBLE PRECISION']
        else:
            expected = ['VARCHAR(255)', 'VARCHAR(255)', 'FLOAT']
        self.assertEquals(col_names, expected)
